{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp input_fn\n",
    "import os\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Function to Create Datasets\n",
    "\n",
    "Function to create datasets to train, eval and predict."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from typing import List, Union, Dict\n",
    "import json\n",
    "from loguru import logger\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from m3tl.params import Params\n",
    "from m3tl.read_write_tfrecord import read_tfrecord, write_tfrecord\n",
    "from m3tl.special_tokens import PREDICT, TRAIN\n",
    "from m3tl.utils import infer_shape_and_type_from_dict, get_is_pyspark\n",
    "from m3tl.preproc_decorator import preprocessing_fn\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Eval Dataset\n",
    "We can get train and eval dataset by passing a problem assigned params and mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "def element_length_func(yield_dict: Dict[str, tf.Tensor]):  # pragma: no cover\n",
    "    input_ids_keys = [k for k in yield_dict.keys() if 'input_ids' in k]\n",
    "    max_length = tf.reduce_sum([tf.shape(yield_dict[k])[0]\n",
    "                               for k in input_ids_keys])\n",
    "    return max_length\n",
    "\n",
    "\n",
    "def train_eval_input_fn(params: Params, mode=TRAIN) -> tf.data.Dataset:\n",
    "    '''\n",
    "    This function will write and read tf record for training\n",
    "    and evaluation.\n",
    "\n",
    "    Arguments:\n",
    "        params {Params} -- Params objects\n",
    "\n",
    "    Keyword Arguments:\n",
    "        mode {str} -- ModeKeys (default: {TRAIN})\n",
    "\n",
    "    Returns:\n",
    "        tf Dataset -- Tensorflow dataset\n",
    "    '''\n",
    "    write_tfrecord(params=params)\n",
    "    \n",
    "    # reading with pyspark is not supported\n",
    "    if get_is_pyspark():\n",
    "        return\n",
    "\n",
    "    dataset_dict = read_tfrecord(params=params, mode=mode)\n",
    "\n",
    "    # make sure the order is correct\n",
    "    dataset_dict_keys = list(dataset_dict.keys())\n",
    "    dataset_list = [dataset_dict[key] for key in dataset_dict_keys]\n",
    "    sample_prob_dict = params.calculate_data_sampling_prob()\n",
    "    weight_list = [\n",
    "        sample_prob_dict[key]\n",
    "        for key in dataset_dict_keys\n",
    "    ]\n",
    "    \n",
    "    logger.info('sampling weights: ')\n",
    "    logger.info(json.dumps(params.problem_sampling_weight_dict, indent=4))\n",
    "    # for problem_chunk_name, weight in params.problem_sampling_weight_dict.items():\n",
    "    #     logger.info('{0}: {1}'.format(problem_chunk_name, weight))\n",
    "\n",
    "    dataset = tf.data.experimental.sample_from_datasets(\n",
    "        datasets=dataset_list, weights=weight_list)\n",
    "    options = tf.data.Options()\n",
    "    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA\n",
    "    dataset = dataset.with_options(options)\n",
    "\n",
    "    if mode == TRAIN:\n",
    "        dataset = dataset.shuffle(params.shuffle_buffer)\n",
    "\n",
    "    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
    "    if params.dynamic_padding:\n",
    "        dataset = dataset.apply(\n",
    "            tf.data.experimental.bucket_by_sequence_length(\n",
    "                element_length_func=element_length_func,\n",
    "                bucket_batch_sizes=params.bucket_batch_sizes,\n",
    "                bucket_boundaries=params.bucket_boundaries\n",
    "            ))\n",
    "    else:\n",
    "        first_example = next(dataset.as_numpy_iterator())\n",
    "        output_shapes, _ = infer_shape_and_type_from_dict(first_example)\n",
    "\n",
    "        if mode == TRAIN:\n",
    "            dataset = dataset.padded_batch(params.batch_size, output_shapes)\n",
    "        else:\n",
    "            dataset = dataset.padded_batch(params.batch_size*2, output_shapes)\n",
    "\n",
    "    return dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-15 17:19:05.812 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "2021-06-15 17:19:05.812 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "2021-06-15 17:19:05.813 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_cls, problem type: cls\n",
      "2021-06-15 17:19:05.813 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_masklm, problem type: masklm\n",
      "2021-06-15 17:19:05.814 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_regression, problem type: regression\n",
      "2021-06-15 17:19:05.814 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "2021-06-15 17:19:05.815 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem weibo_premask_mlm, problem type: premask_mlm\n",
      "2021-06-15 17:19:05.815 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem fake_contrastive_learning, problem type: contrastive_learning\n",
      "2021-06-15 17:19:05.816 | WARNING  | m3tl.base_params:assign_problem:634 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.\n",
      "2021-06-15 17:19:05.818 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-15 17:19:11.410 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "2021-06-15 17:19:11.410 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from m3tl.test_base import TestBase\n",
    "import m3tl\n",
    "import shutil\n",
    "import numpy as np\n",
    "test_base = TestBase()\n",
    "test_base.params.assign_problem(\n",
    "    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm')\n",
    "params = test_base.params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-15 17:19:11.733 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-15 17:19:11.740 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord\n",
      "2021-06-15 17:19:11.771 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-15 17:19:11.777 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord\n",
      "2021-06-15 17:19:11.803 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/train_00000.tfrecord\n",
      "2021-06-15 17:19:11.827 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/eval_00000.tfrecord\n",
      "2021-06-15 17:19:11.905 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/train_00000.tfrecord\n",
      "2021-06-15 17:19:11.955 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/eval_00000.tfrecord\n",
      "2021-06-15 17:19:12.697 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: \n",
      "2021-06-15 17:19:12.698 | INFO     | __main__:train_eval_input_fn:38 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner\": 0.3333333333333333,\n",
      "    \"weibo_fake_multi_cls\": 0.3333333333333333,\n",
      "    \"weibo_masklm\": 0.3333333333333333\n",
      "}\n",
      "2021-06-15 17:19:13.141 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: \n",
      "2021-06-15 17:19:13.142 | INFO     | __main__:train_eval_input_fn:38 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner\": 0.3333333333333333,\n",
      "    \"weibo_fake_multi_cls\": 0.3333333333333333,\n",
      "    \"weibo_masklm\": 0.3333333333333333\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "train_dataset = train_eval_input_fn(\n",
    "    params=params, mode=m3tl.TRAIN)\n",
    "eval_dataset = train_eval_input_fn(\n",
    "    params=params, mode=m3tl.EVAL\n",
    ")\n",
    "\n",
    "_ = next(train_dataset.as_numpy_iterator())\n",
    "_ = next(eval_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-15 17:19:14.003 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-15 17:19:14.010 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord\n",
      "2021-06-15 17:19:14.041 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-15 17:19:14.047 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord\n",
      "2021-06-15 17:19:14.072 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/train_00000.tfrecord\n",
      "2021-06-15 17:19:14.098 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_fake_multi_cls/eval_00000.tfrecord\n",
      "2021-06-15 17:19:14.180 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/train_00000.tfrecord\n",
      "2021-06-15 17:19:14.231 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmp2afsw8rx/weibo_masklm/eval_00000.tfrecord\n",
      "2021-06-15 17:19:14.518 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: \n",
      "2021-06-15 17:19:14.519 | INFO     | __main__:train_eval_input_fn:38 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner\": 0.3333333333333333,\n",
      "    \"weibo_fake_multi_cls\": 0.3333333333333333,\n",
      "    \"weibo_masklm\": 0.3333333333333333\n",
      "}\n",
      "2021-06-15 17:19:15.077 | INFO     | __main__:train_eval_input_fn:37 - sampling weights: \n",
      "2021-06-15 17:19:15.078 | INFO     | __main__:train_eval_input_fn:38 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner\": 0.3333333333333333,\n",
      "    \"weibo_fake_multi_cls\": 0.3333333333333333,\n",
      "    \"weibo_masklm\": 0.3333333333333333\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "# dynamic_padding disabled\n",
    "# have to remove existing tfrecord\n",
    "shutil.rmtree(test_base.tmpfiledir)\n",
    "test_base.params.dynamic_padding = False\n",
    "train_dataset = train_eval_input_fn(\n",
    "    params=test_base.params, mode=m3tl.TRAIN)\n",
    "eval_dataset = train_eval_input_fn(\n",
    "    params=test_base.params, mode=m3tl.EVAL\n",
    ")\n",
    "_ = next(train_dataset.as_numpy_iterator())\n",
    "_ = next(eval_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict Dataset\n",
    "\n",
    "We can create a predict dataset by passing list/generator of inputs and problem assigned params."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def predict_input_fn(input_file_or_list: Union[str, List[str]],\n",
    "                     params: Params,\n",
    "                     mode=PREDICT,\n",
    "                     labels_in_input=False) -> tf.data.Dataset:\n",
    "    '''Input function that takes a file path or list of string and\n",
    "    convert it to tf.dataset\n",
    "\n",
    "    Example:\n",
    "        predict_fn = lambda: predict_input_fn('test.txt', params)\n",
    "        pred = estimator.predict(predict_fn)\n",
    "\n",
    "    Arguments:\n",
    "        input_file_or_list {str or list} -- file path or list of string\n",
    "        params {Params} -- Params object\n",
    "\n",
    "    Keyword Arguments:\n",
    "        mode {str} -- ModeKeys (default: {PREDICT})\n",
    "\n",
    "    Returns:\n",
    "        tf dataset -- tf dataset\n",
    "    '''\n",
    "\n",
    "    # if is string, treat it as path to file\n",
    "    if isinstance(input_file_or_list, str):\n",
    "        inputs = open(input_file_or_list, 'r', encoding='utf8')\n",
    "    else:\n",
    "        inputs = input_file_or_list\n",
    "\n",
    "    # ugly wrapping\n",
    "    def gen():\n",
    "        @preprocessing_fn\n",
    "        def gen_wrapper(params, mode):\n",
    "            return inputs\n",
    "        return gen_wrapper(params, mode)\n",
    "\n",
    "    first_dict = next(gen())\n",
    "\n",
    "    output_shapes, output_type = infer_shape_and_type_from_dict(first_dict)\n",
    "    dataset = tf.data.Dataset.from_generator(\n",
    "        gen, output_types=output_type, output_shapes=output_shapes)\n",
    "\n",
    "    dataset = dataset.padded_batch(\n",
    "        params.batch_size,\n",
    "        output_shapes\n",
    "    )\n",
    "    # dataset = dataset.batch(config.batch_size*2)\n",
    "\n",
    "    return dataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single modal inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from m3tl.utils import set_phase\n",
    "from m3tl.special_tokens import PREDICT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-15 17:19:16.349 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n"
     ]
    }
   ],
   "source": [
    "\n",
    "set_phase(PREDICT)\n",
    "single_dataset = predict_input_fn(\n",
    "    ['this is a test']*5, params=params)\n",
    "first_batch = next(single_dataset.as_numpy_iterator())\n",
    "assert first_batch['text_input_ids'].tolist()[0] == [\n",
    "    101,  8554,  8310,   143, 10060,   102]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multi-modal inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# multi modal input\n",
    "mm_input = [{'text': 'this is a test',\n",
    "             'image': np.zeros(shape=(5, 10), dtype='float32')}] * 5\n",
    "mm_dataset = predict_input_fn(\n",
    "    mm_input, params=params)\n",
    "first_batch = next(mm_dataset.as_numpy_iterator())\n",
    "assert first_batch['text_input_ids'].tolist()[0] == [\n",
    "    101,  8554,  8310,   143, 10060,   102]\n",
    "assert first_batch['image_input_ids'].tolist()[0] == np.zeros(\n",
    "    shape=(5, 10), dtype='float32').tolist()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
